{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W2D5_ReinforcementLearning/W2D5_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Neuromatch Academy: Week 3, Day 4, Tutorial 1\n",
    "# Learning to Predict\n",
    "\n",
    "__Content creators:__ Marcelo Mattar and Eric DeWitt with help from Matt Krause\n",
    "\n",
    "__Content reviewers:__ Byron Galbraith and Michael Waskom\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "\n",
    "# Tutorial objectives\n",
    "  \n",
    "In this tutorial, we will learn how to estimate state-value functions in a classical conditioning paradigm using Temporal Difference (TD) learning and examine TD-errors at the presentation of the conditioned and unconditioned stimulus (CS and US) under different CS-US contingencies. These exercises will provide you with an understanding of both how reward prediction errors (RPEs) behave in classical conditioning and what we should expect to see if Dopamine represents a \"canonical\" model-free RPE. \n",
    "\n",
    "At the end of this tutorial:    \n",
    "* You will learn to use the standard tapped delay line conditioning model\n",
    "* You will understand how RPEs move to CS\n",
    "* You will understand how variability in reward size effects RPEs\n",
    "* You will understand how differences in US-CS timing effect RPEs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "code",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:44.583293Z",
     "iopub.status.busy": "2021-05-25T01:18:44.582756Z",
     "iopub.status.idle": "2021-05-25T01:18:44.890729Z",
     "shell.execute_reply": "2021-05-25T01:18:44.889516Z"
    }
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:44.895184Z",
     "iopub.status.busy": "2021-05-25T01:18:44.894661Z",
     "iopub.status.idle": "2021-05-25T01:18:44.975972Z",
     "shell.execute_reply": "2021-05-25T01:18:44.974934Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Figure settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:44.998390Z",
     "iopub.status.busy": "2021-05-25T01:18:44.992102Z",
     "iopub.status.idle": "2021-05-25T01:18:45.001345Z",
     "shell.execute_reply": "2021-05-25T01:18:45.000908Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Helper functions\n",
    "from matplotlib import ticker\n",
    "\n",
    "def plot_value_function(V, ax=None, show=True):\n",
    "  \"\"\"Plot V(s), the value function\"\"\"\n",
    "  if not ax:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  ax.stem(V, use_line_collection=True)\n",
    "  ax.set_ylabel('Value')\n",
    "  ax.set_xlabel('State')\n",
    "  ax.set_title(\"Value function: $V(s)$\")\n",
    "\n",
    "  if show:\n",
    "    plt.show()\n",
    "\n",
    "def plot_tde_trace(TDE, ax=None, show=True, skip=400):\n",
    "  \"\"\"Plot the TD Error across trials\"\"\"\n",
    "  if not ax:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  indx = np.arange(0, TDE.shape[1], skip)\n",
    "  im = ax.imshow(TDE[:,indx])\n",
    "  positions = ax.get_xticks()\n",
    "  # Avoid warning when setting string tick labels\n",
    "  ax.xaxis.set_major_locator(ticker.FixedLocator(positions))\n",
    "  ax.set_xticklabels([f\"{int(skip * x)}\" for x in positions])\n",
    "  ax.set_title('TD-error over learning')\n",
    "  ax.set_ylabel('State')\n",
    "  ax.set_xlabel('Iterations')\n",
    "  ax.figure.colorbar(im)\n",
    "  if show:\n",
    "    plt.show()\n",
    "\n",
    "def learning_summary_plot(V, TDE):\n",
    "  \"\"\"Summary plot for Ex1\"\"\"\n",
    "  fig, (ax1, ax2) = plt.subplots(nrows = 2, gridspec_kw={'height_ratios': [1, 2]})\n",
    "\n",
    "  plot_value_function(V, ax=ax1, show=False)\n",
    "  plot_tde_trace(TDE, ax=ax2, show=False)\n",
    "  plt.tight_layout()\n",
    "\n",
    "def reward_guesser_title_hint(r1, r2):\n",
    "  \"\"\"\"Provide a mildly obfuscated hint for a demo.\"\"\"\n",
    "  if (r1==14 and r2==6) or (r1==6 and r2==14):\n",
    "    return \"Technically correct...(the best kind of correct)\"\n",
    "\n",
    "  if  ~(~(r1+r2) ^ 11) - 1 == (6 | 24): # Don't spoil the fun :-)\n",
    "    return \"Congratulations! You solved it!\"\n",
    "\n",
    "  return \"Keep trying....\"\n",
    "\n",
    "#@title Default title text\n",
    "class ClassicalConditioning:\n",
    "\n",
    "    def __init__(self, n_steps, reward_magnitude, reward_time):\n",
    "\n",
    "        # Task variables\n",
    "        self.n_steps = n_steps\n",
    "        self.n_actions = 0\n",
    "        self.cs_time = int(n_steps/4) - 1\n",
    "\n",
    "        # Reward variables\n",
    "        self.reward_state = [0,0]\n",
    "        self.reward_magnitude = None\n",
    "        self.reward_probability = None\n",
    "        self.reward_time = None\n",
    "\n",
    "        self.set_reward(reward_magnitude, reward_time)\n",
    "\n",
    "        # Time step at which the conditioned stimulus is presented\n",
    "\n",
    "        # Create a state dictionary\n",
    "        self._create_state_dictionary()\n",
    "\n",
    "    def set_reward(self, reward_magnitude, reward_time):\n",
    "\n",
    "        \"\"\"\n",
    "        Determine reward state and magnitude of reward\n",
    "        \"\"\"\n",
    "        if reward_time >= self.n_steps - self.cs_time:\n",
    "            self.reward_magnitude = 0\n",
    "\n",
    "        else:\n",
    "            self.reward_magnitude = reward_magnitude\n",
    "            self.reward_state = [1, reward_time]\n",
    "\n",
    "    def get_outcome(self, current_state):\n",
    "\n",
    "        \"\"\"\n",
    "        Determine next state and reward\n",
    "        \"\"\"\n",
    "        # Update state\n",
    "        if current_state < self.n_steps - 1:\n",
    "            next_state = current_state + 1\n",
    "        else:\n",
    "            next_state = 0\n",
    "\n",
    "        # Check for reward\n",
    "        if self.reward_state == self.state_dict[current_state]:\n",
    "            reward = self.reward_magnitude\n",
    "        else:\n",
    "            reward = 0\n",
    "\n",
    "        return next_state, reward\n",
    "\n",
    "    def _create_state_dictionary(self):\n",
    "\n",
    "        \"\"\"\n",
    "        This dictionary maps number of time steps/ state identities\n",
    "        in each episode to some useful state attributes:\n",
    "\n",
    "        state      - 0 1 2 3 4 5 (cs) 6 7 8 9 10 11 12 ...\n",
    "        is_delay   - 0 0 0 0 0 0 (cs) 1 1 1 1  1  1  1 ...\n",
    "        t_in_delay - 0 0 0 0 0 0 (cs) 1 2 3 4  5  6  7 ...\n",
    "        \"\"\"\n",
    "        d = 0\n",
    "\n",
    "        self.state_dict = {}\n",
    "        for s in range(self.n_steps):\n",
    "            if s <= self.cs_time:\n",
    "                self.state_dict[s] = [0,0]\n",
    "            else:\n",
    "                d += 1 # Time in delay\n",
    "                self.state_dict[s] = [1,d]\n",
    "\n",
    "class MultiRewardCC(ClassicalConditioning):\n",
    "  \"\"\"Classical conditioning paradigm, except that one randomly selected reward,\n",
    "    magnitude, from a list, is delivered of a single fixed reward.\"\"\"\n",
    "  def __init__(self, n_steps, reward_magnitudes, reward_time=None):\n",
    "    \"\"\"\"Build a multi-reward classical conditioning environment\n",
    "      Args:\n",
    "        - nsteps: Maximum number of steps\n",
    "        - reward_magnitudes: LIST of possible reward magnitudes.\n",
    "        - reward_time: Single fixed reward time\n",
    "      Uses numpy global random state.\n",
    "      \"\"\"\n",
    "    super().__init__(n_steps, 1, reward_time)\n",
    "    self.reward_magnitudes = reward_magnitudes\n",
    "\n",
    "  def get_outcome(self, current_state):\n",
    "    next_state, reward = super().get_outcome(current_state)\n",
    "    if reward:\n",
    "      reward=np.random.choice(self.reward_magnitudes)\n",
    "    return next_state, reward\n",
    "\n",
    "\n",
    "class ProbabilisticCC(ClassicalConditioning):\n",
    "  \"\"\"Classical conditioning paradigm, except that rewards are stochastically omitted.\"\"\"\n",
    "  def __init__(self, n_steps, reward_magnitude, reward_time=None, p_reward=0.75):\n",
    "    \"\"\"\"Build a multi-reward classical conditioning environment\n",
    "      Args:\n",
    "        - nsteps: Maximum number of steps\n",
    "        - reward_magnitudes: Reward magnitudes.\n",
    "        - reward_time: Single fixed reward time.\n",
    "        - p_reward: probability that reward is actually delivered in rewarding state\n",
    "      Uses numpy global random state.\n",
    "      \"\"\"\n",
    "    super().__init__(n_steps, reward_magnitude, reward_time)\n",
    "    self.p_reward = p_reward\n",
    "\n",
    "  def get_outcome(self, current_state):\n",
    "    next_state, reward = super().get_outcome(current_state)\n",
    "    if reward:\n",
    "      reward*= int(np.random.uniform(size=1)[0] < self.p_reward)\n",
    "    return next_state, reward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 1: TD-learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 519
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:45.009080Z",
     "iopub.status.busy": "2021-05-25T01:18:45.008532Z",
     "iopub.status.idle": "2021-05-25T01:18:45.053046Z",
     "shell.execute_reply": "2021-05-25T01:18:45.053649Z"
    },
    "outputId": "6360c4ae-7abf-4de6-efad-c496484c4104"
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Introduction\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"YoNbc9M92YY\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "__Environment:__\n",
    "\n",
    "- The agent experiences the environment in episodes or trials. \n",
    "- Episodes terminate by transitioning to the inter-trial-interval (ITI) state and they are initiated from the ITI state as well. We clamp the value of the terminal/ITI states to zero. \n",
    "- The classical conditioning environment is composed of a sequence of states that the agent deterministically transitions through. Starting at State 0, the agent moves to State 1 in the first step, from State 1 to State 2 in the second, and so on.  These states represent time in the tapped delay line representation\n",
    "- Within each episode, the agent is presented a CS and US (reward). \n",
    "- The CS is always presented at 1/4 of the total duration of the trial. The US (reward) is then delivered after the CS. The interval between the CS and US is specified by `reward_time`.\n",
    "- The agent's goal is to learn to predict expected rewards from each state in the trial. \n",
    "\n",
    "\n",
    "**General concepts**\n",
    "\n",
    "* Return $G_{t}$: future cumulative reward, which can be written in arecursive form\n",
    "\\begin{align}\n",
    "G_{t} &= \\sum \\limits_{k = 0}^{\\infty} \\gamma^{k} r_{t+k+1} \\\\\n",
    "&= r_{t+1} + \\gamma G_{t+1}\n",
    "\\end{align}\n",
    "where $\\gamma$ is discount factor that controls the importance of future rewards, and $\\gamma \\in [0, 1]$. $\\gamma$ may also be interpreted as probability of continuing the trajectory.\n",
    "* Value funtion $V_{\\pi}(s_t=s)$: expecation of the return\n",
    "\\begin{align}\n",
    "V_{\\pi}(s_t=s) &= \\mathbb{E} [ G_{t}\\; | \\; s_t=s, a_{t:\\infty}\\sim\\pi] \\\\\n",
    "& = \\mathbb{E} [ r_{t+1} + \\gamma G_{t+1}\\; | \\; s_t=s, a_{t:\\infty}\\sim\\pi]\n",
    "\\end{align}\n",
    "With an assumption of **Markov process**, we thus have:\n",
    "\\begin{align}\n",
    "V_{\\pi}(s_t=s) &= \\mathbb{E} [ r_{t+1} + \\gamma V_{\\pi}(s_{t+1})\\; | \\; s_t=s, a_{t:\\infty}\\sim\\pi] \\\\\n",
    "&= \\sum_a \\pi(a|s) \\sum_{r, s'}p(s', r)(r + V_{\\pi}(s_{t+1}=s'))\n",
    "\\end{align}\n",
    "\n",
    "**Temporal difference (TD) learning**\n",
    "\n",
    "* With a Markovian assumption, we can use $V(s_{t+1})$ as an imperfect proxy for the true value $G_{t+1}$ (Monte Carlo bootstrapping), and thus obtain the generalised equation to calculate TD-error:\n",
    "\\begin{align}\n",
    "\\delta_{t} = r_{t+1} + \\gamma V(s_{t+1}) - V(s_{t})\n",
    "\\end{align}\n",
    "\n",
    "* Value updated by using the learning rate constant $\\alpha$:\n",
    "\\begin{align}\n",
    "V(s_{t}) \\leftarrow V(s_{t}) + \\alpha \\delta_{t}\n",
    "\\end{align}\n",
    "\n",
    "  (Reference: https://web.stanford.edu/group/pdplab/pdphandbook/handbookch10.html)\n",
    "\n",
    "\n",
    "\n",
    "__Definitions:__\n",
    "\n",
    "* TD-error:\n",
    "\\begin{align}\n",
    "\\delta_{t} = r_{t+1} + \\gamma V(s_{t+1}) - V(s_{t})\n",
    "\\end{align}\n",
    "\n",
    "* Value updates:\n",
    "\\begin{align}\n",
    "V(s_{t}) \\leftarrow V(s_{t}) + \\alpha \\delta_{t}\n",
    "\\end{align}\n"
   ]
  },
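  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "__Worked example (a sketch):__\n",
    "\n",
    "To make the two definitions above concrete, here is a single update worked by hand. The numbers are illustrative assumptions: all values start at zero, $\\gamma = 0.98$, $\\alpha = 0.001$, and the agent receives a reward $r_{t+1} = 10$ on this step.\n",
    "\\begin{align}\n",
    "\\delta_{t} &= r_{t+1} + \\gamma V(s_{t+1}) - V(s_{t}) = 10 + 0.98 \\cdot 0 - 0 = 10 \\\\\n",
    "V(s_{t}) &\\leftarrow V(s_{t}) + \\alpha \\delta_{t} = 0 + 0.001 \\cdot 10 = 0.01\n",
    "\\end{align}\n",
    "\n",
    "With a learning rate this small, each trial moves the estimate only slightly, which suggests why many trials are needed before the value estimates settle."
   ]
  },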
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 1: TD-learning with guaranteed rewards\n",
    "   \n",
    "Implement TD-learning to estimate the state-value function in the classical-conditioning world with guaranteed rewards, with a fixed magnitude, at a fixed delay after the conditioned stimulus, CS. Save TD-errors over learning (i.e., over trials) so we can visualize them afterwards. \n",
    "\n",
    "In order to simulate the effect of the CS, you should only update $V(s_{t})$ during the delay period after CS. This period is indicated by the boolean variable `is_delay`. This can be implemented by multiplying the expression for updating the value function by `is_delay`.\n",
    "\n",
    "Use the provided code to estimate the value function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:45.061126Z",
     "iopub.status.busy": "2021-05-25T01:18:45.059736Z",
     "iopub.status.idle": "2021-05-25T01:18:45.061767Z",
     "shell.execute_reply": "2021-05-25T01:18:45.062221Z"
    }
   },
   "outputs": [],
   "source": [
    "def td_learner(env, n_trials, gamma=0.98, alpha=0.001):\n",
    "  \"\"\" Temporal Difference learning\n",
    "\n",
    "  Args:\n",
    "    env (object): the environment to be learned\n",
    "    n_trials (int): the number of trials to run\n",
    "    gamma (float): temporal discount factor\n",
    "    alpha (float): learning rate\n",
    "\n",
    "  Returns:\n",
    "    ndarray, ndarray: the value function and temporal difference error arrays\n",
    "  \"\"\"\n",
    "  V = np.zeros(env.n_steps) # Array to store values over states (time)\n",
    "  TDE = np.zeros((env.n_steps, n_trials)) # Array to store TD errors\n",
    "\n",
    "  for n in range(n_trials):\n",
    "    state = 0 # Initial state\n",
    "    for t in range(env.n_steps):\n",
    "      # Get next state and next reward\n",
    "      next_state, reward = env.get_outcome(state)\n",
    "      # Is the current state in the delay period (after CS)?\n",
    "      is_delay = env.state_dict[state][0]\n",
    "\n",
    "      ########################################################################\n",
    "      ## TODO for students: implement TD error and value function update\n",
    "      # Fill out function and remove\n",
    "      raise NotImplementedError(\"Student excercise: implement TD error and value function update\")\n",
    "      #################################################################################\n",
    "      # Write an expression to compute the TD-error\n",
    "      TDE[state, n] = ...\n",
    "\n",
    "      # Write an expression to update the value function\n",
    "      V[state] += ...\n",
    "\n",
    "      # Update state\n",
    "      state = next_state\n",
    "\n",
    "  return V, TDE\n",
    "\n",
    "\n",
    "# Uncomment once the td_learner function is complete\n",
    "# env = ClassicalConditioning(n_steps=40, reward_magnitude=10, reward_time=10)\n",
    "# V, TDE = td_learner(env, n_trials=20000)\n",
    "# learning_summary_plot(V, TDE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 430
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:45.121324Z",
     "iopub.status.busy": "2021-05-25T01:18:45.084207Z",
     "iopub.status.idle": "2021-05-25T01:18:47.233153Z",
     "shell.execute_reply": "2021-05-25T01:18:47.232642Z"
    },
    "outputId": "7e947f90-bf43-465c-8090-78b4e2eb41a6"
   },
   "outputs": [],
   "source": [
    "#to_remove solution\n",
    "def td_learner(env, n_trials, gamma=0.98, alpha=0.001):\n",
    "  \"\"\" Temporal Difference learning\n",
    "\n",
    "  Args:\n",
    "    env (object): the environment to be learned\n",
    "    n_trials (int): the number of trials to run\n",
    "    gamma (float): temporal discount factor\n",
    "    alpha (float): learning rate\n",
    "\n",
    "  Returns:\n",
    "    ndarray, ndarray: the value function and temporal difference error arrays\n",
    "  \"\"\"\n",
    "  V = np.zeros(env.n_steps) # Array to store values over states (time)\n",
    "  TDE = np.zeros((env.n_steps, n_trials)) # Array to store TD errors\n",
    "\n",
    "  for n in range(n_trials):\n",
    "    state = 0 # Initial state\n",
    "    for t in range(env.n_steps):\n",
    "      # Get next state and next reward\n",
    "      next_state, reward = env.get_outcome(state)\n",
    "      # Is the current state in the delay period (after CS)?\n",
    "      is_delay = env.state_dict[state][0]\n",
    "\n",
    "      # Write an expression to compute the TD-error\n",
    "      TDE[state, n] = (reward + gamma * V[next_state] - V[state])\n",
    "\n",
    "      # Write an expression to update the value function\n",
    "      V[state] += alpha * TDE[state, n] * is_delay\n",
    "\n",
    "      # Update state\n",
    "      state = next_state\n",
    "\n",
    "  return V, TDE\n",
    "\n",
    "\n",
    "env = ClassicalConditioning(n_steps=40, reward_magnitude=10, reward_time=10)\n",
    "V, TDE = td_learner(env, n_trials=20000)\n",
    "with plt.xkcd():\n",
    "  learning_summary_plot(V, TDE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Interactive Demo 1: US to CS Transfer \n",
    "\n",
    "During classical conditioning, the subject's behavioral response (e.g., salivating) transfers from the unconditioned stimulus (US; like the smell of tasty food) to the conditioned stimulus (CS; like Pavlov ringing his bell) that predicts it. Reward prediction errors play an important role in this process by adjusting the value of states according to their expected, discounted return.\n",
    "\n",
    "Use the widget below to examine how reward prediction errors change over time. Before training (orange line), only the reward state has high reward prediction error. As training progresses (blue line, slider), the reward prediction errors shift to the conditioned stimulus, where they end up when the trial is complete (green line). \n",
    "\n",
    "Dopamine neurons, which are thought to carry reward prediction errors _in vivo_, show exactly the same behavior!\n"
   ]
  },
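  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "__Sketch: where the values and TD-errors settle__\n",
    "\n",
    "The shift described above can be made concrete under a few assumptions: learning has converged in the deterministic environment of Exercise 1, $R$ denotes the reward magnitude, and $s^{(k)}$ denotes the delay state that lies $k$ steps before the rewarded state (these symbols are introduced here for illustration). At convergence the TD-error is zero in every delay state, so $V(s_t) = r_{t+1} + \\gamma V(s_{t+1})$ there. No reward follows the rewarded state, so later states keep a value of zero, and unrolling the recursion backwards gives\n",
    "\\begin{align}\n",
    "V(s^{(k)}) = \\gamma^{k} R, \\qquad k = 0, 1, \\dots, K\n",
    "\\end{align}\n",
    "where $s^{(K)}$ is the first delay state. States before the delay period are never updated, so their value stays at zero. The step from the CS state into $s^{(K)}$ therefore keeps a nonzero TD-error:\n",
    "\\begin{align}\n",
    "\\delta_{\\text{CS}} = 0 + \\gamma V(s^{(K)}) - 0 = \\gamma^{K+1} R\n",
    "\\end{align}\n",
    "\n",
    "If we read the environment code as placing the reward ten steps after the CS state ($K = 9$), then with $\\gamma = 0.98$ and $R = 10$ this gives $\\delta_{\\text{CS}} = 0.98^{10} \\cdot 10 \\approx 8.2$, while the TD-error at the reward itself shrinks toward zero."
   ]
  },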
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 462,
     "referenced_widgets": [
      "a758486cc37d420985a213f47e925762",
      "9f001ba3b42748f4a4d1e489aee00550",
      "116636a168f540d1ae597a14d9e6dc2e",
      "093b1cb07f62460ba2385102b80e214a",
      "ebf2fbbb55154dc4ad525313275f8056",
      "60c4e4d386434c73b7bf0448c8c3e936",
      "cc7a2dea483947eeb74392813a08819f"
     ]
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:47.259207Z",
     "iopub.status.busy": "2021-05-25T01:18:47.257674Z",
     "iopub.status.idle": "2021-05-25T01:18:47.519615Z",
     "shell.execute_reply": "2021-05-25T01:18:47.519146Z"
    },
    "outputId": "fdd401d1-8fb6-4a71-d957-4e826c287138"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "n_trials = 20000\n",
    "\n",
    "@widgets.interact\n",
    "def plot_tde_by_trial(trial = widgets.IntSlider(value=5000, min=0, max=n_trials-1 , step=1, description=\"Trial #\")):\n",
    "  if 'TDE' not in globals():\n",
    "    print(\"Complete Exercise 1 to enable this interactive demo!\")\n",
    "  else:\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.axhline(0, color='k') # Use this + basefmt=' ' to keep the legend clean.\n",
    "    ax.stem(TDE[:, 0], linefmt='C1-', markerfmt='C1d', basefmt=' ',\n",
    "            label=\"Before Learning (Trial 0)\",\n",
    "            use_line_collection=True)\n",
    "    ax.stem(TDE[:, -1], linefmt='C2-', markerfmt='C2s', basefmt=' ',\n",
    "            label=\"After Learning (Trial $\\infty$)\",\n",
    "            use_line_collection=True)\n",
    "    ax.stem(TDE[:, trial], linefmt='C0-', markerfmt='C0o', basefmt=' ',\n",
    "            label=f\"Trial {trial}\",\n",
    "            use_line_collection=True)\n",
    "\n",
    "    ax.set_xlabel(\"State in trial\")\n",
    "    ax.set_ylabel(\"TD Error\")\n",
    "    ax.set_title(\"Temporal Difference Error by Trial\")\n",
    "    ax.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Interactive Demo 2: Learning Rates and Discount Factors\n",
    "\n",
    "Our TD-learning agent has two parameters that control how it learns: $\\alpha$, the learning rate, and $\\gamma$, the discount factor. In Exercise 1, we set these parameters to $\\alpha=0.001$ and $\\gamma=0.98$ for you. Here, you'll investigate how changing these parameters alters the model that TD-learning learns.\n",
    "\n",
    "Before enabling the interactive demo below, take a moment to think about the functions of these two parameters. $\\alpha$ controls the size of the Value function updates produced by each TD-error. In our simple, deterministic world, will this affect the final model we learn? Is a larger $\\alpha$ necessarily better in more complex, realistic environments?\n",
    "\n",
    "The discount rate $\\gamma$ applies an exponentially-decaying weight to returns occuring in the future, rather than the present timestep. How does this affect the model we learn? What happens when $\\gamma=0$ or $\\gamma \\geq 1$?\n",
    "\n",
    "Use the widget to test your hypotheses.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 495,
     "referenced_widgets": [
      "d508b140df9244aa8496bee5e443a61e",
      "17dccada1bfe40c38775ea3573c73d55",
      "15b797d9aced416aa50479090b69d7a6",
      "fed5cacebcb94223b82273fd45165c40",
      "a4d5eb402e8a423b84988fb9f14eeeb9",
      "78a57d148f14489ba08134b141d7d3bb",
      "c2f50af9ffc04a62b5af6ac4191a5b08",
      "225fd0fc568446968ec1e11c325a00e7",
      "90e2617cf1d7455ba5c8e703ba5027aa",
      "d674b87c34004447867ba32fa7501242"
     ]
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:47.551583Z",
     "iopub.status.busy": "2021-05-25T01:18:47.549881Z",
     "iopub.status.idle": "2021-05-25T01:18:49.471716Z",
     "shell.execute_reply": "2021-05-25T01:18:49.472161Z"
    },
    "outputId": "32ca352d-1141-4e07-eb06-5216639a61d1"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "@widgets.interact\n",
    "def plot_summary_alpha_gamma(alpha = widgets.FloatSlider(value=0.0001, min=0.0001, max=0.1, step=0.0001, readout_format='.4f', description=\"alpha\"),\n",
    "                             gamma = widgets.FloatSlider(value=0.980, min=0, max=1.1, step=0.010, description=\"gamma\")):\n",
    "  env = ClassicalConditioning(n_steps=40, reward_magnitude=10, reward_time=10)\n",
    "  try:\n",
    "    V_params, TDE_params = td_learner(env, n_trials=20000, gamma=gamma, alpha=alpha)\n",
    "  except NotImplementedError:\n",
    "    print(\"Finish Exercise 1 to enable this interactive demo\")\n",
    "\n",
    "  learning_summary_plot(V_params,TDE_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:49.479693Z",
     "iopub.status.busy": "2021-05-25T01:18:49.478753Z",
     "iopub.status.idle": "2021-05-25T01:18:49.481388Z",
     "shell.execute_reply": "2021-05-25T01:18:49.481843Z"
    }
   },
   "outputs": [],
   "source": [
    "#to_remove explanation\n",
    "\"\"\"\n",
    "Alpha determines how fast the model learns. In the simple, deterministic world\n",
    "we're using here, this allows the moodel to quickly converge onto the \"true\"\n",
    "model that heavily values the conditioned stimulus. In more complex environments,\n",
    "however, excessively large values of alpha can slow, or even prevent, learning,\n",
    "as we'll see later.\n",
    "\n",
    "Gamma effectively controls how much the model cares about the future: larger values of\n",
    "gamma cause the model to weight future rewards nearly as much as present ones. At gamma=1,\n",
    "the model weights all rewards, regardless of when they occur, equally and when greater than one, it\n",
    "starts to *prefer* rewards in the future, rather than the present (this is rarely good).\n",
    "When gamma=0, however, the model becomes greedy and only considers rewards that\n",
    "can be obtained immediately.\n",
    " \"\"\";"
   ]
  },
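  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "__Sketch: why the range of $\\gamma$ matters__\n",
    "\n",
    "The role of $\\gamma$ can be illustrated with a hypothetical reward stream that is not part of this tutorial's environment: suppose the agent received the same reward $r$ on every step, forever. The return is then a geometric series,\n",
    "\\begin{align}\n",
    "G_{t} = \\sum \\limits_{k = 0}^{\\infty} \\gamma^{k} r = \\frac{r}{1 - \\gamma} \\qquad \\text{for } 0 \\leq \\gamma < 1\n",
    "\\end{align}\n",
    "For example, $\\gamma = 0.98$ and $r = 1$ give $G_{t} = 50$, while $\\gamma = 0$ gives $G_{t} = r$, the immediate reward alone. For $\\gamma \\geq 1$ the series no longer converges.\n",
    "\n",
    "In an episodic task such as the one used here, only finitely many terms are nonzero, so the return stays finite. With $\\gamma > 1$, however, each term $\\gamma^{k} r$ grows with the delay $k$, which is the sense in which the learner starts to prefer later rewards."
   ]
  },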
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 2: TD-learning with varying reward magnitudes\n",
    "\n",
    "In the previous exercise, the environment was as simple as possible. On every trial, the CS predicted the same reward, at the same time, with 100% certainty. In the next few exercises, we will make the environment more progressively more complicated and examine the TD-learner's behavior. \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Interactive Demo 3: Match the Value Functions\n",
    "\n",
    "First, will replace the environment with one that dispenses one of several rewards, chosen at random. Shown below is the final value function $V$ for a TD learner that was trained in an enviroment where the CS predicted a reward of 6 or 14 units; both rewards were equally likely). \n",
    "\n",
    "Can you find another pair of rewards that cause the agent to learn the same value function? Assume each reward will be dispensed 50% of the time. \n",
    "\n",
    "Hints:\n",
    "* Carefully consider the definition of the value function $V$. This can be solved analytically.\n",
    "* There is no need to change $\\alpha$ or $\\gamma$. \n",
    "* Due to the randomness, there may be a small amount of variation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 495,
     "referenced_widgets": [
      "bc86c3d260304a7d869b4883ee15958e",
      "4c2e366503f24403bf3b89cf18d7bcd7",
      "dcef477f4c7f4d63b0b592da1b8910c0",
      "6aebe27048ad4b85ba4dbef72519b4d5",
      "65c5d35d9bdd4fe6b21a24472e19ef47",
      "2eef5df7eea844a89fdb6521bab5c26a",
      "001cad6c97c148deac60bf3a055a5aaa",
      "d4a23a231af7443ea1a145a2928a95bf",
      "5754ec764b1a4c4cb7fd574af344b352",
      "806fa794443445eaa4447bcb7e3e8c62"
     ]
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:49.490212Z",
     "iopub.status.busy": "2021-05-25T01:18:49.489660Z",
     "iopub.status.idle": "2021-05-25T01:18:54.165333Z",
     "shell.execute_reply": "2021-05-25T01:18:54.164809Z"
    },
    "outputId": "160612ad-b958-4bdf-b539-686b56386b95"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "\n",
    "#@markdown Make sure you execute this cell to enable the widget!\n",
    "\n",
    "n_trials = 20000\n",
    "np.random.seed(2020)\n",
    "rng_state = np.random.get_state()\n",
    "env = MultiRewardCC(40, [6, 14], reward_time=10)\n",
    "V_multi, TDE_multi = td_learner(env, n_trials, gamma=0.98, alpha=0.001)\n",
    "\n",
    "@widgets.interact\n",
    "def reward_guesser_interaction(r1 = widgets.IntText(value=0, min=0, max=50, description=\"Reward 1\"),\n",
    "                               r2 = widgets.IntText(value=0, min=0, max=50, description=\"Reward 2\")):\n",
    "  try:\n",
    "    env2 = MultiRewardCC(40, [r1, r2], reward_time=10)\n",
    "    V_guess, _ = td_learner(env2, n_trials, gamma=0.98, alpha=0.001)\n",
    "    fig, ax = plt.subplots()\n",
    "    m, l, _ = ax.stem(V_multi, linefmt='y-', markerfmt='yo', basefmt=' ', label=\"Target\",\n",
    "            use_line_collection=True)\n",
    "    m.set_markersize(15)\n",
    "    m.set_markerfacecolor('none')\n",
    "    l.set_linewidth(4)\n",
    "    m, _, _ = ax.stem(V_guess, linefmt='r', markerfmt='rx', basefmt=' ', label=\"Guess\",\n",
    "                      use_line_collection=True)\n",
    "    m.set_markersize(15)\n",
    "\n",
    "    ax.set_xlabel(\"State\")\n",
    "    ax.set_ylabel(\"Value\")\n",
    "    ax.set_title(\"Guess V(s)\\n\" + reward_guesser_title_hint(r1, r2))\n",
    "    ax.legend()\n",
    "  except NotImplementedError:\n",
    "    print(\"Please finish Exercise 1 first!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 2.1 Examining the TD Error\n",
    "\n",
    "Run the cell below to plot the TD errors from our multi-reward environment. A new feature appears in this plot? What is it? Why does it happen?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 421
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:54.216108Z",
     "iopub.status.busy": "2021-05-25T01:18:54.199946Z",
     "iopub.status.idle": "2021-05-25T01:18:54.490254Z",
     "shell.execute_reply": "2021-05-25T01:18:54.490700Z"
    },
    "outputId": "5ce4f44c-351f-4992-bbad-c1b32c6702e7"
   },
   "outputs": [],
   "source": [
    "plot_tde_trace(TDE_multi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:54.495420Z",
     "iopub.status.busy": "2021-05-25T01:18:54.494905Z",
     "iopub.status.idle": "2021-05-25T01:18:54.499695Z",
     "shell.execute_reply": "2021-05-25T01:18:54.499181Z"
    }
   },
   "outputs": [],
   "source": [
    "#to_remove explanation\n",
    "\"\"\"\n",
    "The TD trace now takes on negative values because the reward delivered is\n",
    "sometimes larger than the expected reward and sometimes smaller.\n",
    "\"\"\";"
   ]
  },
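  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the explanation above concrete, here is a short worked example. It is only a sketch under stated assumptions: the two rewards (6 and 14) are taken to be equally likely, learning is taken to have converged, and the value of the state following the reward is treated as zero. None of these is confirmed by the code shown in this section, and the exact state indexing in `td_learner` may differ.\n",
    "\n",
    "$$\\mathbb{E}[r] = \\tfrac{1}{2}(6) + \\tfrac{1}{2}(14) = 10$$\n",
    "\n",
    "$$\\delta_{\\text{large}} = 14 - \\mathbb{E}[r] = +4, \\qquad \\delta_{\\text{small}} = 6 - \\mathbb{E}[r] = -4$$\n",
    "\n",
    "Under these assumptions, the TD error at the time of reward would not settle at zero. It would instead be roughly $+4$ on large-reward trials and $-4$ on small-reward trials."
   ]
  },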
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 3: TD-learning with probabilistic rewards\n",
    "\n",
    "In this environment, we'll return to delivering a single reward of ten units. However, it will be delivered intermittently: on 20 percent of trials, the CS will be shown but the agent will not receive the usual reward; the remaining 80% will proceed as usual.\n",
    "\n",
    " Run the cell below to simulate. How does this compare with the previous experiment?\n",
    "\n",
    "Earlier in the notebook, we saw that changing $\\alpha$ had little effect on learning in a deterministic environment. What happens if you set it to an large value, like 1, in this noisier scenario? Does it seem like it will _ever_ converge?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 430
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:56.159463Z",
     "iopub.status.busy": "2021-05-25T01:18:54.503997Z",
     "iopub.status.idle": "2021-05-25T01:18:59.101294Z",
     "shell.execute_reply": "2021-05-25T01:18:59.100819Z"
    },
    "outputId": "cb317e95-80a0-4805-9b7d-b13d33e7a78d"
   },
   "outputs": [],
   "source": [
    "np.random.set_state(rng_state) # Resynchronize everyone's notebooks\n",
    "n_trials = 20000\n",
    "try:\n",
    "  env = ProbabilisticCC(n_steps=40, reward_magnitude=10, reward_time=10,\n",
    "                        p_reward=0.8)\n",
    "  V_stochastic, TDE_stochastic = td_learner(env, n_trials*2, alpha=1)\n",
    "  learning_summary_plot(V_stochastic, TDE_stochastic)\n",
    "except NotImplementedError:\n",
    "  print(\"Please finish Exercise 1 first\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:59.105677Z",
     "iopub.status.busy": "2021-05-25T01:18:59.105112Z",
     "iopub.status.idle": "2021-05-25T01:18:59.110419Z",
     "shell.execute_reply": "2021-05-25T01:18:59.109936Z"
    }
   },
   "outputs": [],
   "source": [
    "#to_remove explanation\n",
    "\"\"\"\n",
    "The multi-reward and probabilistic reward environments are equivalent. You\n",
    "could simulate a probabilistic reward of 10 units, delivered 50% of the time,\n",
    "by having a mixture of 10 and 0 unit rewards, or vice versa. The take-home message\n",
    "from these last three exercises is that the *average* or expected reward is\n",
    "what matters during TD learning.\n",
    "\n",
    "Large values of alpha prevent the TD Learner from converging. As a result, the\n",
    "value function seems implausible: one state may have an extremely low value while\n",
    "the neighboring ones remain high. This pattern persists even if training continues\n",
    "for hundreds of thousands of trials.\n",
    "\"\"\";"
   ]
  },
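  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Both points in the explanation above can be checked with a short worked example. This is a sketch, not a description of the exact implementation: it assumes the standard TD(0) update $V(s) \\leftarrow V(s) + \\alpha\\,\\delta$, looks only at the state where the reward is delivered, and treats the value of the following state as zero.\n",
    "\n",
    "$$\\mathbb{E}[r] = 0.8 \\times 10 + 0.2 \\times 0 = 8$$\n",
    "\n",
    "$$\\alpha = 1: \\quad V(s) \\leftarrow V(s) + 1 \\cdot \\big(r - V(s)\\big) = r \\in \\{0, 10\\}$$\n",
    "\n",
    "Under these assumptions, a learner with $\\alpha = 1$ simply copies the most recent outcome, so its estimate keeps jumping between 0 and 10. A small $\\alpha$ averages over many trials and would settle near the expected reward of 8."
   ]
  },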
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this notebook, we have developed a simple TD Learner and examined how its state representations and reward prediction errors evolve during training. By manipulating its environment and parameters ($\\alpha$, $\\gamma$), you developed an intuition for how it behaves.\n",
    "\n",
    "This simple model closely resembles the behavior of subjects undergoing classical conditioning tasks and the dopamine neurons that may underlie that behavior. You may have implemented TD-reset or used the model to recreate a common experimental error. The update rule used here has been extensively studied for [more than 70 years](https://www.pnas.org/content/108/Supplement_3/15647) as a possible explanation for artificial and biological learning. \n",
    "\n",
    "However, you may have noticed that something is missing from this notebook. We carefully calculated the value of each state, but did not use it to actually do anything. Using values to plan _**Actions**_ is coming up next!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 2: Removing the CS\n",
    "\n",
    "In Exercise 1, you (should have) included a term that depends on the conditioned stimulus. Remove it and see what happens. Do you understand why?\n",
    "This phenomenon often fools people attempting to train animals--beware!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:18:59.114708Z",
     "iopub.status.busy": "2021-05-25T01:18:59.114156Z",
     "iopub.status.idle": "2021-05-25T01:18:59.119431Z",
     "shell.execute_reply": "2021-05-25T01:18:59.119003Z"
    }
   },
   "outputs": [],
   "source": [
    "#to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "You should only be updating V[state] once the conditioned stimulus appears.\n",
    "\n",
    "If you remove this term the Value Function becomes periodic, dropping towards zero\n",
    "right after the reward and gradually rising towards the end of the trial. This\n",
    "behavior is actually correct, because the model is learning the time until the\n",
    "*next* reward, and State 37 is closer to a reward than State 21 or 22.\n",
    "\n",
    "In an actual experiment, the animal often just wants rewards; it doesn't care about\n",
    "/your/ experiment or trial structure!\n",
    "\"\"\";"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W3D4_Tutorial1",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
